package lock

import (
	"context"
	"github.com/go-redis/redis/v8"
	"ksd-social-api/commons/global"
	"time"
)

// InitRedisLock creates the shared Redis client and stores it in global.Redis.
//
// Connection settings are hard-coded: Redis at 127.0.0.1:6379, no password,
// logical database 0.
//
// Example (not executed here):
//
//	rl := NewRedisLock("test_key3", "test_value1", time.Minute, global.Redis)
//	locked, err := rl.TryLock(3 * time.Second)
//	// ... critical section when locked ...
//	rl.UnLock()
func InitRedisLock() {
	opts := &redis.Options{
		Addr:     "127.0.0.1:6379",
		Password: "",
		DB:       0,
	}
	global.Redis = redis.NewClient(opts)
}

// RedisLock is a distributed lock implemented with a single Redis key.
// Ownership is determined by the value stored under the key.
type RedisLock struct {
	key        string        // Redis key that represents the lock
	value      string        // holder token compared by UnLock/RefreshLock; should be unique per holder, typically a UUID
	expiration time.Duration // TTL applied to the key by Lock and RefreshLock
	redisCli   *redis.Client // client used for every Redis command
}

// NewRedisLock builds a RedisLock bound to the given Redis client.
//
// Parameters:
//   - key: the Redis key that represents the lock.
//   - value: the token that identifies this holder. UnLock only deletes the
//     key, and RefreshLock only extends it, when the stored value equals this
//     token, so it should be unique per holder (e.g. a UUID).
//   - expiration: the TTL used when the lock is taken or refreshed.
//   - cli: the client used for all commands.
//
// It returns nil when key or value is empty or cli is nil; callers must check
// the result before use.
func NewRedisLock(key, value string, expiration time.Duration, cli *redis.Client) *RedisLock {
	valid := key != "" && value != "" && cli != nil
	if !valid {
		return nil
	}

	rl := new(RedisLock)
	rl.key = key
	rl.value = value
	rl.expiration = expiration
	rl.redisCli = cli
	return rl
}

// Lock makes a single, non-blocking attempt to acquire the lock. It uses
// SET NX to store rl.value under rl.key.
//
// The key receives rl.expiration as its TTL. An expiration less than or equal
// to zero means the key never expires; such a lock must be released
// explicitly with UnLock.
//
// It returns true if the lock was acquired, false if the key already exists,
// and a non-nil error if the Redis call failed.
func (rl *RedisLock) Lock() (bool, error) {
	expiration := rl.expiration
	// go-redis maps only 0 (plain SETNX) and redis.KeepTTL (-1) to "no expiry".
	// Any other negative duration is sent to Redis as a negative PX/EX and
	// rejected with "invalid expire time". Normalise so every non-positive
	// expiration behaves as documented above.
	if expiration < 0 {
		expiration = 0
	}

	acquired, err := rl.redisCli.SetNX(context.Background(), rl.key, rl.value, expiration).Result()
	if err != nil {
		return false, err
	}
	return acquired, nil
}

// TryLock calls Lock repeatedly, waiting up to 20ms between attempts. It
// stops when the lock is acquired, when Lock returns an error, or when
// waitTime has elapsed.
//
// At least one attempt is always made, even if waitTime is zero or negative.
// It never sleeps past the deadline, and there is no extra sleep after the
// final failed attempt. The total time is therefore roughly waitTime plus the
// duration of the last Lock call.
//
// Results:
//   - (true, nil): the lock was acquired.
//   - (false, err): a Lock call failed.
//   - (false, nil): the lock was still held elsewhere when waitTime ran out.
func (rl *RedisLock) TryLock(waitTime time.Duration) (bool, error) {
	const retryInterval = 20 * time.Millisecond

	deadline := time.Now().Add(waitTime)
	for {
		locked, err := rl.Lock()
		if locked || err != nil {
			return locked, err
		}

		remaining := time.Until(deadline)
		if remaining <= 0 {
			return false, nil
		}
		// Sleep for the retry interval, but never beyond the deadline.
		if remaining > retryInterval {
			remaining = retryInterval
		}
		time.Sleep(remaining)
	}
}

func (rl *RedisLock) UnLock() (bool, error) {
	script := redis.NewScript(`
	if redis.call("get", KEYS[1]) == ARGV[1] then
		return redis.call("del", KEYS[1])
	else
		return 0
	end
	`)

	result, err := script.Run(context.Background(), rl.redisCli, []string{rl.key}, rl.value).Int64()
	if err != nil {
		return false, err
	}

	return result > 0, nil
}

// RefreshLock 存在则更新过期时间,不存在则创建key
func (rl *RedisLock) RefreshLock() (bool, error) {
	script := redis.NewScript(`
	local val = redis.call("GET", KEYS[1])
	if not val then
		redis.call("setex", KEYS[1], ARGV[2], ARGV[1])
		return 2
	elseif val == ARGV[1] then
		return redis.call("expire", KEYS[1], ARGV[2])
	else
		return 0
	end
	`)

	result, err := script.Run(context.Background(), rl.redisCli, []string{rl.key}, rl.value, rl.expiration/time.Second).Int64()
	if err != nil {
		return false, err
	}

	return result > 0, nil
}
